import importlib.util
from pathlib import Path

import numpy as np
import pandas as pd
import statsmodels.api as sm


def load_shiftshare_module(root: Path):
    """Load ``analysis/19_iv_shiftshare_rollout_absorb.py`` under *root* as a module.

    The file name starts with a digit, so it cannot be brought in with a plain
    ``import`` statement; it is loaded directly from its file path instead and
    given the module name ``"shiftshare"``.

    Args:
        root: Project root directory (``Path`` or ``str``) that contains the
            ``analysis`` folder.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for the file.
        Any error raised while reading or executing the file is propagated.
    """
    path = Path(root) / "analysis" / "19_iv_shiftshare_rollout_absorb.py"
    spec = importlib.util.spec_from_file_location("shiftshare", path)
    # Validate explicitly (not with ``assert``, which is stripped under -O), and
    # before the spec is used to build the module.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec/loader for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def fit_ols_absorbed(mod, df: pd.DataFrame, y: str, controls: list[str]):
    """OLS of *y* on netusoft and netusoft x edu_high with region and ct FE absorbed.

    The estimation sample keeps only rows that are complete for the outcome,
    ``netusoft``, ``edu_high``, the FE identifiers (``ct``, ``region``), the
    controls, and the three shift-share instruments. Presumably the instruments
    are included so the OLS sample lines up with the IV sample (TODO confirm).
    Fixed effects are partialled out through ``mod.absorb_two_way``. The
    returned statsmodels fit uses standard errors clustered on ``region``.

    Args:
        mod: Loaded shift-share module providing ``absorb_two_way``.
        df: Input data frame.
        y: Outcome column name.
        controls: Extra covariate column names.

    Returns:
        The fitted statsmodels OLS results object.

    Raises:
        ValueError: If no rows survive the missing-value filter.
    """
    instrument_cols = ["z_ss_cov30", "z_ss_cov100", "z_ss_fttp"]
    required = [y, "netusoft", "edu_high", "ct", "region", *controls, *instrument_cols]
    sample = df[required].dropna().copy()
    if sample.empty:
        raise ValueError(f"No rows after dropna for outcome={y}")

    # Interaction term built on raw values, before the FE are absorbed.
    sample["netusoft_x_edu"] = sample["netusoft"] * sample["edu_high"]

    to_absorb = [y, "netusoft", "netusoft_x_edu", *controls, "edu_high"]
    absorbed = mod.absorb_two_way(sample, to_absorb, fe_a="region", fe_b="ct")

    regressors = controls + ["edu_high", "netusoft", "netusoft_x_edu"]
    design = sm.add_constant(absorbed[regressors], has_constant="add")
    cluster_kwds = {"groups": sample["region"]}
    return sm.OLS(absorbed[y], design).fit(cov_type="cluster", cov_kwds=cluster_kwds)


def summarize_generic(res, label: str) -> dict:
    """Summarize the netusoft and netusoft x edu_high effects from a fitted model.

    Works with results objects that expose standard errors as ``std_errors``
    (e.g. linearmodels) or as ``bse`` (e.g. statsmodels). Coefficients or
    standard errors absent from the fit are reported as NaN.

    Args:
        res: Fitted results with ``params`` (and ``std_errors`` or ``bse``)
            supporting ``.get(name, default)``.
        label: Model label stored under ``"model"``.

    Returns:
        Dict with the model label, ``nobs`` (int, or NaN if unavailable), the
        low-education effect (``netusoft``), the high-education effect
        (``netusoft + netusoft_x_edu``), and the high-minus-low difference.
    """
    # Prefer ``std_errors`` when present; otherwise fall back to ``bse``.
    ses = res.std_errors if hasattr(res, "std_errors") else res.bse

    def _se(name: str) -> float:
        return float(ses.get(name, np.nan))

    b1 = float(res.params.get("netusoft", np.nan))
    b2 = float(res.params.get("netusoft_x_edu", np.nan))
    nobs = getattr(res, "nobs", None)
    return {
        "model": label,
        # int(np.nan) raises ValueError, so coerce only when nobs is available.
        "nobs": int(nobs) if pd.notna(nobs) else np.nan,
        "coef_lowEdu": b1,
        "se_lowEdu": _se("netusoft"),
        "coef_highEdu": b1 + b2,
        "coef_deltaHighMinusLow": b2,
        "se_deltaHighMinusLow": _se("netusoft_x_edu"),
    }


def _iv_diagnostics(res) -> dict:
    """Extract first-stage F statistics and the Anderson-Rubin p-value from an IV fit.

    Assumes ``res.first_stage.diagnostics`` is a DataFrame indexed by the
    endogenous regressors with an ``"f.stat"`` column, as in linearmodels IV
    results (TODO confirm against the shift-share module). If the
    ``anderson_rubin`` object has no ``pval`` attribute, the p-value is NaN.
    """
    diag = res.first_stage.diagnostics
    return {
        "fsF_netusoft": float(diag.loc["netusoft", "f.stat"]),
        "fsF_netusoft_x_edu": float(diag.loc["netusoft_x_edu", "f.stat"]),
        "AR_pvalue": float(getattr(res.anderson_rubin, "pval", np.nan)),
    }


def main() -> None:
    """Fit OLS, 2SLS and LIML for each participation outcome and write comparison tables.

    Loads the shift-share analysis module, builds its dataset, and fits the
    three estimators for every outcome. Writes
    ``outputs/iv_shiftshare_ols_vs_iv.csv`` (raw values) and
    ``outputs/iv_shiftshare_ols_vs_iv.md`` (rounded markdown). OLS rows have
    empty first-stage / Anderson-Rubin columns.
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)

    mod = load_shiftshare_module(root)
    df = mod.build_dataset(root)

    outcomes = [
        ("y_vote", "Vote (yes/no)"),
        ("y_part_index", "Participation index"),
        ("y_sgnptit", "Signed petition"),
        ("y_contplt", "Contacted politician/official"),
        ("y_badge", "Worn campaign badge/sticker"),
        ("y_bctprd", "Boycotted products"),
    ]
    controls = ["agea", "gndr", "hinctnta"]

    # IV-type estimators share one call signature and the same diagnostics.
    iv_estimators = [
        ("2SLS (shift-share IV)", mod.fit_iv_interaction_absorb),
        ("LIML (shift-share IV)", mod.fit_liml_interaction_absorb),
    ]

    rows = []
    for y, ylab in outcomes:
        ols = fit_ols_absorbed(mod, df, y, controls=controls)
        rows.append({**summarize_generic(ols, "OLS (absorbed FE)"), "outcome": y, "label": ylab})

        for model_label, fit in iv_estimators:
            res = fit(df, y, controls=controls)
            rows.append(
                {
                    **summarize_generic(res, model_label),
                    "outcome": y,
                    "label": ylab,
                    **_iv_diagnostics(res),
                }
            )

    columns = [
        "label",
        "outcome",
        "model",
        "nobs",
        "coef_lowEdu",
        "se_lowEdu",
        "coef_highEdu",
        "coef_deltaHighMinusLow",
        "se_deltaHighMinusLow",
        "fsF_netusoft",
        "fsF_netusoft_x_edu",
        "AR_pvalue",
    ]
    tab = pd.DataFrame(rows)[columns].copy()

    csv_path = out_dir / "iv_shiftshare_ols_vs_iv.csv"
    md_path = out_dir / "iv_shiftshare_ols_vs_iv.md"
    csv_path.write_text(tab.to_csv(index=False), encoding="utf-8")

    # Minimal pretty markdown: 4 dp for estimates/SEs, 2 dp for first-stage F,
    # 4 significant digits for p-values; missing values render as "".
    number_formats = {
        **dict.fromkeys(
            ["coef_lowEdu", "se_lowEdu", "coef_highEdu", "coef_deltaHighMinusLow", "se_deltaHighMinusLow"],
            ".4f",
        ),
        **dict.fromkeys(["fsF_netusoft", "fsF_netusoft_x_edu"], ".2f"),
        "AR_pvalue": ".4g",
    }
    pretty = tab.copy()
    for col, spec in number_formats.items():
        pretty[col] = pretty[col].map(lambda x, spec=spec: format(x, spec) if pd.notna(x) else "")
    md_path.write_text(pretty.to_markdown(index=False) + "\n", encoding="utf-8")

    print(f"Wrote: {csv_path}")
    print(f"Wrote: {md_path}")


if __name__ == "__main__":
    # Run the OLS vs. shift-share IV comparison only when executed as a script.
    main()

